{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "29de864c",
   "metadata": {},
   "source": [
    "# Training a Model\n",
    "\n",
    "In the [last notebook](https://www.pybroker.com/en/latest/notebooks/5.%20Writing%20Indicators.html), we learned how to write stock indicators in **PyBroker**. Indicators are a good starting point for developing a trading strategy. But to create a successful strategy, it is likely that a more sophisticated approach using predictive modeling will be needed.\n",
    "\n",
    "Fortunately, one of the main features of **PyBroker** is the ability to train and backtest machine learning models. These models can utilize indicators as features to make more accurate predictions about market movements. Once trained, these models can be backtested using a popular technique known as [Walkforward Analysis](https://www.youtube.com/watch?v=WBZ_Vv-iMv4), which simulates how a strategy would perform during actual trading.\n",
    "\n",
    "We'll explain Walkforward Analysis in more depth later in this notebook. But first, let's get started with some needed imports!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6699f2df",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pybroker\n",
    "from numba import njit\n",
    "from pybroker import Strategy, StrategyConfig, YFinance"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a2c7fb5",
   "metadata": {},
   "source": [
    "As with [DataSource](https://www.pybroker.com/en/latest/reference/pybroker.data.html#pybroker.data.DataSource) and [Indicator](https://www.pybroker.com/en/latest/reference/pybroker.indicator.html#pybroker.indicator.Indicator) data, **PyBroker** can also cache trained models to disk. You can enable caching for all three by calling [pybroker.enable_caches](https://www.pybroker.com/en/latest/reference/pybroker.cache.html#pybroker.cache.enable_caches):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d6a5ae5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "pybroker.enable_caches('walkforward_strategy')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b7a78ba",
   "metadata": {},
   "source": [
    "In [the last notebook](https://www.pybroker.com/en/latest/notebooks/5.%20Writing%20Indicators.html), we implemented an indicator that calculates the close-minus-moving-average (CMMA) using [NumPy](https://www.numpy.org) and [Numba](https://numba.pydata.org/). Here's the code for the CMMA indicator again:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bcbef85f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cmma(bar_data, lookback):\n",
    "\n",
    "    @njit  # Enable Numba JIT.\n",
    "    def vec_cmma(values):\n",
    "        # Initialize the result array.\n",
    "        n = len(values)\n",
    "        out = np.array([np.nan for _ in range(n)])\n",
    "        \n",
    "        # For all bars starting at lookback:\n",
    "        for i in range(lookback, n):\n",
    "            # Calculate the moving average for the lookback.\n",
    "            ma = 0\n",
    "            for j in range(i - lookback, i):\n",
    "                ma += values[j]\n",
    "            ma /= lookback\n",
    "            # Subtract the moving average from value.\n",
    "            out[i] = values[i] - ma\n",
    "        return out\n",
    "    \n",
    "    # Calculate for close prices.\n",
    "    return vec_cmma(bar_data.close)\n",
    "\n",
    "cmma_20 = pybroker.indicator('cmma_20', cmma, lookback=20)"
   ]
  },
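  {
   "cell_type": "markdown",
   "id": "sketch-cmma-plain",
   "metadata": {},
   "source": [
    "To make the arithmetic concrete, here is a minimal plain-Python sketch of the same loop, without Numba or **PyBroker**. The helper name `cmma_plain` and the prices are made up for illustration. Note that, as in the code above, the moving average covers the `lookback` bars *before* the current bar and does not include the current bar itself.\n",
    "\n",
    "```python\n",
    "def cmma_plain(values, lookback):\n",
    "    # Same logic as vec_cmma above, written with plain Python lists.\n",
    "    out = [float('nan')] * len(values)\n",
    "    for i in range(lookback, len(values)):\n",
    "        # Average of the lookback bars that precede bar i.\n",
    "        ma = sum(values[i - lookback:i]) / lookback\n",
    "        out[i] = values[i] - ma\n",
    "    return out\n",
    "\n",
    "closes = [10.0, 11.0, 12.0, 13.0, 15.0]  # Made-up close prices.\n",
    "result = cmma_plain(closes, lookback=3)\n",
    "# The first 3 entries are NaN because there is no full lookback yet.\n",
    "print(result[3])  # 13 - (10 + 11 + 12) / 3 = 2.0\n",
    "print(result[4])  # 15 - (11 + 12 + 13) / 3 = 3.0\n",
    "```"
   ]
  },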
  {
   "cell_type": "markdown",
   "id": "216876da",
   "metadata": {},
   "source": [
    "## Train and Backtest\n",
    "\n",
    "Next, we want to build a model that predicts the next day's return using the 20-day CMMA. [Simple linear regression](https://en.wikipedia.org/wiki/Simple_linear_regression) is a good starting point for experimentation. Below we import a [LinearRegression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html) model from [scikit-learn](https://scikit-learn.org/stable/):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "40a28c9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.metrics import r2_score"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "905a3d5c",
   "metadata": {},
   "source": [
    "We create a ```train_slr``` function to train the ```LinearRegression``` model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "93056c20",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_slr(symbol, train_data, test_data):\n",
    "    # Train\n",
    "    # Previous day close prices.\n",
    "    train_prev_close = train_data['close'].shift(1)\n",
    "    # Calculate daily returns.\n",
    "    train_daily_returns = (train_data['close'] - train_prev_close) / train_prev_close\n",
    "    # Predict next day's return.\n",
    "    train_data['pred'] = train_daily_returns.shift(-1)\n",
    "    train_data = train_data.dropna()\n",
    "    # Train the LinearRegression model to predict the next day's return\n",
    "    # given the 20-day CMMA.\n",
    "    X_train = train_data[['cmma_20']]\n",
    "    y_train = train_data[['pred']]\n",
    "    model = LinearRegression()\n",
    "    model.fit(X_train, y_train)\n",
    "    \n",
    "    # Test\n",
    "    test_prev_close = test_data['close'].shift(1)\n",
    "    test_daily_returns = (test_data['close'] - test_prev_close) / test_prev_close\n",
    "    test_data['pred'] = test_daily_returns.shift(-1)\n",
    "    test_data = test_data.dropna()\n",
    "    X_test = test_data[['cmma_20']]\n",
    "    y_test = test_data[['pred']]\n",
    "    # Make predictions from test data.\n",
    "    y_pred = model.predict(X_test)\n",
    "    # Print goodness of fit.\n",
    "    r2 = r2_score(y_test, np.squeeze(y_pred))\n",
    "    print(symbol, f'R^2={r2}')\n",
    "    \n",
    "    # Return the trained model and columns to use as input data.\n",
    "    return model, ['cmma_20']"
   ]
  },
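  {
   "cell_type": "markdown",
   "id": "sketch-target-shift",
   "metadata": {},
   "source": [
    "The target column is the least obvious part of `train_slr`, so here is a small worked sketch of the same idea in plain Python with made-up prices. The names `closes`, `returns`, `targets`, and `rows` are hypothetical. Each bar's target is the return realized on the *following* bar, which is why the last bar has no target and is dropped.\n",
    "\n",
    "```python\n",
    "closes = [100.0, 102.0, 101.0, 104.0]  # Made-up close prices.\n",
    "\n",
    "# Daily return of each bar relative to the previous bar.\n",
    "# The first bar has no previous close, so its return is undefined.\n",
    "returns = [None] + [(c - p) / p for p, c in zip(closes, closes[1:])]\n",
    "\n",
    "# Shift returns back by one bar: the target for bar i is the return of bar i + 1.\n",
    "targets = returns[1:] + [None]\n",
    "\n",
    "# Drop bars without a target (the last bar), like dropna() does above.\n",
    "rows = [(c, t) for c, t in zip(closes, targets) if t is not None]\n",
    "for close, target in rows:\n",
    "    print(close, round(target, 4))\n",
    "```\n",
    "\n",
    "In the real function, rows where `cmma_20` is still undefined (the first 20 bars) are dropped by the same `dropna()` call."
   ]
  },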
  {
   "cell_type": "markdown",
   "id": "4f1fd74e",
   "metadata": {},
   "source": [
    "The ```train_slr``` function uses the 20-day CMMA as the input feature, or predictor, for the ```LinearRegression``` model. The function then fits the ```LinearRegression``` model to the training data for that stock symbol.\n",
    "\n",
    "After fitting the model, the function uses the testing data to evaluate the model's accuracy, specifically by computing the [R-squared](https://en.wikipedia.org/wiki/Coefficient_of_determination) score. The R-squared score provides a measure of how well the ```LinearRegression``` model fits the testing data.\n",
    "\n",
    "The final output of the ```train_slr``` function is the trained ```LinearRegression``` model specifically for that stock symbol, along with the ```cmma_20``` column, which is to be used as input data when making predictions. **PyBroker** will use this model to predict the next day's return of the stock during the backtest. The ```train_slr``` function will be called for each stock symbol, and the trained models will be used to predict the next day's return for each individual stock.\n",
    "\n",
    "Once the function to train the model has been defined, it needs to be registered with **PyBroker**. This is done by creating a new [ModelSource](https://www.pybroker.com/en/latest/reference/pybroker.model.html#pybroker.model.ModelSource) instance using the [pybroker.model](https://www.pybroker.com/en/latest/reference/pybroker.model.html#pybroker.model.model) function. The arguments to this function are the name of the model (```'slr'``` in this case), the function that will train the model (```train_slr```), and a list of indicators to use as inputs for the model (in this case, ```cmma_20```)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e46d8c01",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_slr = pybroker.model('slr', train_slr, indicators=[cmma_20])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a37b917",
   "metadata": {},
   "source": [
    "To create a trading strategy that uses the trained model, we create a new [Strategy](https://www.pybroker.com/en/latest/reference/pybroker.strategy.html#pybroker.strategy.Strategy) object with the [YFinance](https://www.pybroker.com/en/latest/reference/pybroker.data.html#pybroker.data.YFinance) data source and the start and end dates of the backtest period."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6059e0af",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = StrategyConfig(bootstrap_sample_size=100)\n",
    "strategy = Strategy(YFinance(), '3/1/2017', '3/1/2022', config)\n",
    "strategy.add_execution(None, ['NVDA', 'AMD'], models=model_slr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "728da208",
   "metadata": {},
   "source": [
    "The [add_execution](https://www.pybroker.com/en/latest/reference/pybroker.strategy.html#pybroker.strategy.Strategy.add_execution) method is then called on the [Strategy](https://www.pybroker.com/en/latest/reference/pybroker.strategy.html#pybroker.strategy.Strategy) object to specify the details of the trading execution. In this case, a ```None``` value is passed as the first argument, which means that no trading function will be used during the backtest.\n",
    "\n",
    "The last step is to run the backtest by calling the [backtest](https://www.pybroker.com/en/latest/reference/pybroker.strategy.html#pybroker.strategy.Strategy.backtest) method on the ```Strategy``` object, with a ```train_size``` of ```0.5``` to specify that the model should be trained on the first half of the backtest data, and tested on the second half."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f9484f84",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Backtesting: 2017-03-01 00:00:00 to 2022-03-01 00:00:00\n",
      "\n",
      "Loading bar data...\n",
      "[*********************100%***********************]  2 of 2 completed\n",
      "Loaded bar data: 0:00:00 \n",
      "\n",
      "Computing indicators...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100% (2 of 2) |##########################| Elapsed Time: 0:00:01 Time:  0:00:01\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Train split: 2017-03-01 00:00:00 to 2019-08-28 00:00:00\n",
      "AMD R^2=-0.006808549721842416\n",
      "NVDA R^2=-0.004416132743176426\n",
      "Finished training models: 0:00:00 \n",
      "\n",
      "Finished backtest: 0:00:01\n"
     ]
    }
   ],
   "source": [
    "strategy.backtest(train_size=0.5)"
   ]
  },
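  {
   "cell_type": "markdown",
   "id": "sketch-negative-r2",
   "metadata": {},
   "source": [
    "The R-squared scores printed above are slightly negative. That is possible because R-squared compares the model's squared error against the squared error of always predicting the mean of the test targets, so a model that does worse than the mean scores below zero. The sketch below illustrates this with the standard definition `1 - SS_res / SS_tot`. The `r2` helper and the return values are made up for illustration.\n",
    "\n",
    "```python\n",
    "def r2(y_true, y_pred):\n",
    "    # Standard coefficient of determination: 1 - SS_res / SS_tot.\n",
    "    mean = sum(y_true) / len(y_true)\n",
    "    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))\n",
    "    ss_tot = sum((t - mean) ** 2 for t in y_true)\n",
    "    return 1 - ss_res / ss_tot\n",
    "\n",
    "y_true = [0.01, -0.02, 0.03, -0.01]  # Made-up daily returns.\n",
    "\n",
    "# Always predicting the mean of the targets gives a score of zero.\n",
    "mean_pred = [sum(y_true) / len(y_true)] * len(y_true)\n",
    "print(r2(y_true, mean_pred))\n",
    "\n",
    "# Predictions with the wrong sign are worse than the mean, so the score is negative.\n",
    "bad_pred = [-t for t in y_true]\n",
    "print(r2(y_true, bad_pred))\n",
    "```"
   ]
  },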
  {
   "cell_type": "markdown",
   "id": "912b0c7d",
   "metadata": {},
   "source": [
    "## Walkforward Analysis\n",
    "\n",
    "**PyBroker** employs a powerful algorithm known as [Walkforward Analysis](https://www.youtube.com/watch?v=WBZ_Vv-iMv4) to perform backtesting. The algorithm partitions the backtest data into a fixed number of time windows, each containing a train-test split of data.\n",
    "\n",
    "The Walkforward Analysis algorithm then proceeds to \"walk forward\" in time, in the same manner that a trading strategy would be executed in the real world. The model is first trained on the earliest window and then evaluated on the test data in that window.\n",
    "\n",
    "As the algorithm moves forward to the next window in time, the test data from the previous window becomes part of the training data. This process continues until all of the time windows are evaluated.\n",
    "\n",
    "![Walkforward Diagram](https://github.com/edtechre/pybroker/blob/master/docs/_static/walkforward.png?raw=true)\n",
    "\n",
    "By using this approach, the Walkforward Analysis algorithm simulates how a trading strategy would be retrained and applied over time. Every model is evaluated only on data that comes after its training data, which gives a more realistic estimate of out-of-sample performance.\n",
    "\n",
    "Let's consider a trading strategy that generates buy and sell signals from the [LinearRegression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html) model that we trained earlier. The strategy is implemented as the ```hold_long``` function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b065eb8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hold_long(ctx):\n",
    "    if not ctx.long_pos():\n",
    "        # Buy if the next bar is predicted to have a positive return:\n",
    "        if ctx.preds('slr')[-1] > 0:\n",
    "            ctx.buy_shares = 100\n",
    "    else:\n",
    "        # Sell if the next bar is predicted to have a negative return:\n",
    "        if ctx.preds('slr')[-1] < 0:\n",
    "            ctx.sell_shares = 100\n",
    "            \n",
    "strategy.clear_executions()\n",
    "strategy.add_execution(hold_long, ['NVDA', 'AMD'], models=model_slr)"
   ]
  },
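  {
   "cell_type": "markdown",
   "id": "sketch-hold-long-rule",
   "metadata": {},
   "source": [
    "Stripped of the **PyBroker** context, the decision rule amounts to a two-state machine. The sketch below is a simplification: the `simulate_signals` helper and the predictions are made up, and it assumes an order takes effect immediately, ignoring how and when orders are filled in a real backtest.\n",
    "\n",
    "```python\n",
    "def simulate_signals(preds):\n",
    "    # Track whether a long position is open and record one action per bar.\n",
    "    in_position = False\n",
    "    actions = []\n",
    "    for pred in preds:\n",
    "        if not in_position and pred > 0:\n",
    "            actions.append('buy')\n",
    "            in_position = True\n",
    "        elif in_position and pred < 0:\n",
    "            actions.append('sell')\n",
    "            in_position = False\n",
    "        else:\n",
    "            actions.append('hold')\n",
    "    return actions\n",
    "\n",
    "preds = [0.01, 0.02, -0.01, -0.02, 0.03]  # Made-up predicted returns.\n",
    "print(simulate_signals(preds))\n",
    "# ['buy', 'hold', 'sell', 'hold', 'buy']\n",
    "```"
   ]
  },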
  {
   "cell_type": "markdown",
   "id": "1263e1ab",
   "metadata": {},
   "source": [
    "The ```hold_long``` function opens a long position when the model predicts a positive return for the next bar, and then closes the position when the model predicts a negative return.\n",
    "\n",
    "The [ctx.preds('slr')](https://www.pybroker.com/en/latest/reference/pybroker.context.html#pybroker.context.ExecContext.preds) method is used to access the predictions made by the ```'slr'``` model for the current stock symbol being executed in the function (NVDA or AMD). The predictions are stored in a [NumPy array](https://numpy.org/doc/stable/reference/generated/numpy.array.html), and the most recent prediction for the current stock symbol is accessed using ```ctx.preds('slr')[-1]```, which is the model's prediction of the next bar's return.\n",
    "\n",
    "Now that we have defined a trading strategy and registered the ```'slr'``` model, we can run the backtest using the Walkforward Analysis algorithm.\n",
    "\n",
    "The backtest is run by calling the [walkforward](https://www.pybroker.com/en/latest/reference/pybroker.strategy.html#pybroker.strategy.Strategy.walkforward) method on the ```Strategy``` object, with the desired number of time windows and train/test split ratio. In this case, we will use 3 time windows, each with a 50/50 train-test split.\n",
    "\n",
    "Additionally, since our ```'slr'``` model makes a prediction for one bar in the future, we need to specify the ```lookahead``` parameter as ```1```. This is necessary to ensure that training data does not leak into the test boundary. The ```lookahead``` parameter should always be set to the number of bars in the future being predicted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c7566d0f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Backtesting: 2017-03-01 00:00:00 to 2022-03-01 00:00:00\n",
      "\n",
      "Loaded cached bar data.\n",
      "\n",
      "Loaded cached indicator data.\n",
      "\n",
      "Train split: 2017-03-06 00:00:00 to 2018-06-01 00:00:00\n",
      "AMD R^2=-0.007950114729117885\n",
      "NVDA R^2=-0.04203364470839133\n",
      "Finished training models: 0:00:00 \n",
      "\n",
      "Test split: 2018-06-04 00:00:00 to 2019-08-30 00:00:00\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100% (314 of 314) |######################| Elapsed Time: 0:00:00 Time:  0:00:00\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Train split: 2018-06-04 00:00:00 to 2019-08-30 00:00:00\n",
      "AMD R^2=0.0006422677593683757\n",
      "NVDA R^2=-0.023591728578221893\n",
      "Finished training models: 0:00:00 \n",
      "\n",
      "Test split: 2019-09-03 00:00:00 to 2020-11-27 00:00:00\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100% (314 of 314) |######################| Elapsed Time: 0:00:00 Time:  0:00:00\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Train split: 2019-09-03 00:00:00 to 2020-11-27 00:00:00\n",
      "AMD R^2=-0.015508227883924253\n",
      "NVDA R^2=-0.4567200095787838\n",
      "Finished training models: 0:00:00 \n",
      "\n",
      "Test split: 2020-11-30 00:00:00 to 2022-02-28 00:00:00\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100% (314 of 314) |######################| Elapsed Time: 0:00:00 Time:  0:00:00\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Calculating bootstrap metrics: sample_size=100, samples=10000...\n",
      "Calculated bootstrap metrics: 0:00:03 \n",
      "\n",
      "Finished backtest: 0:00:04\n"
     ]
    }
   ],
   "source": [
    "result = strategy.walkforward(\n",
    "    warmup=20,\n",
    "    windows=3,\n",
    "    train_size=0.5,\n",
    "    lookahead=1,\n",
    "    calc_bootstrap=True\n",
    ")"
   ]
  },
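  {
   "cell_type": "markdown",
   "id": "sketch-rolling-windows",
   "metadata": {},
   "source": [
    "The train and test dates in the log above show each window's test range becoming the next window's train range. The following sketch reproduces that pattern on bar indexes. It is an illustration only and not the actual **PyBroker** implementation: the `rolling_windows` helper is hypothetical, and it ignores the `warmup` and `lookahead` parameters.\n",
    "\n",
    "```python\n",
    "def rolling_windows(n_bars, windows, train_size):\n",
    "    # Each window holds train_len + test_len bars and slides forward by test_len.\n",
    "    # Assumes 0 < train_size < 1.\n",
    "    ratio = train_size / (1 - train_size)\n",
    "    test_len = int(n_bars / (windows + ratio))\n",
    "    train_len = int(test_len * ratio)\n",
    "    splits = []\n",
    "    for w in range(windows):\n",
    "        train_start = w * test_len\n",
    "        train_end = train_start + train_len  # Exclusive end index.\n",
    "        test_end = train_end + test_len      # Exclusive end index.\n",
    "        splits.append(((train_start, train_end), (train_end, test_end)))\n",
    "    return splits\n",
    "\n",
    "splits = rolling_windows(n_bars=1240, windows=3, train_size=0.5)\n",
    "for train, test in splits:\n",
    "    print('train', train, 'test', test)\n",
    "# train (0, 310) test (310, 620)\n",
    "# train (310, 620) test (620, 930)\n",
    "# train (620, 930) test (930, 1240)\n",
    "```\n",
    "\n",
    "With a 50/50 split the train and test ranges have equal length, so each test range is reused in full as the next train range."
   ]
  },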
  {
   "cell_type": "markdown",
   "id": "f8619725",
   "metadata": {},
   "source": [
    "During the backtesting process using the Walkforward Analysis algorithm, the ```'slr'``` model is trained on a given window's training data, and then the ```hold_long``` function runs on the same window's test data.\n",
    "\n",
    "The model is trained on the training data to make predictions about the next day's returns. The ```hold_long``` function then uses these predictions to make buy or sell decisions for the current day's trading session.\n",
    "\n",
    "By evaluating the performance of the trading strategy on the test data for each window, we can see how well the strategy is likely to perform in real-world trading conditions. This process is repeated for each time window in the backtest, using the results to evaluate the overall performance of the trading strategy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2bb849bb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>trade_count</td>\n",
       "      <td>43.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>initial_market_value</td>\n",
       "      <td>100000.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>end_market_value</td>\n",
       "      <td>109831.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>total_pnl</td>\n",
       "      <td>12645.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>unrealized_pnl</td>\n",
       "      <td>-2814.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>total_return_pct</td>\n",
       "      <td>12.645000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>total_profit</td>\n",
       "      <td>20566.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>total_loss</td>\n",
       "      <td>-7921.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>total_fees</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>max_drawdown</td>\n",
       "      <td>-14177.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>max_drawdown_pct</td>\n",
       "      <td>-12.272121</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>win_rate</td>\n",
       "      <td>76.744186</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>loss_rate</td>\n",
       "      <td>23.255814</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>winning_trades</td>\n",
       "      <td>33.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>losing_trades</td>\n",
       "      <td>10.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>avg_pnl</td>\n",
       "      <td>294.069767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>avg_return_pct</td>\n",
       "      <td>5.267674</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>avg_trade_bars</td>\n",
       "      <td>25.488372</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>avg_profit</td>\n",
       "      <td>623.212121</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>avg_profit_pct</td>\n",
       "      <td>9.237576</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>avg_winning_trade_bars</td>\n",
       "      <td>19.151515</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>avg_loss</td>\n",
       "      <td>-792.100000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>avg_loss_pct</td>\n",
       "      <td>-7.833000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>avg_losing_trade_bars</td>\n",
       "      <td>46.400000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>largest_win</td>\n",
       "      <td>2715.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>largest_win_pct</td>\n",
       "      <td>9.320000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>largest_win_bars</td>\n",
       "      <td>2.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>largest_loss</td>\n",
       "      <td>-5054.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>largest_loss_pct</td>\n",
       "      <td>-16.140000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>largest_loss_bars</td>\n",
       "      <td>43.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>max_wins</td>\n",
       "      <td>13.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>max_losses</td>\n",
       "      <td>2.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>sharpe</td>\n",
       "      <td>0.023425</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>profit_factor</td>\n",
       "      <td>1.094471</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>ulcer_index</td>\n",
       "      <td>1.177116</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>upi</td>\n",
       "      <td>0.009193</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>equity_r2</td>\n",
       "      <td>0.772082</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>std_error</td>\n",
       "      <td>4191.846954</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      name          value\n",
       "0              trade_count      43.000000\n",
       "1     initial_market_value  100000.000000\n",
       "2         end_market_value  109831.000000\n",
       "3                total_pnl   12645.000000\n",
       "4           unrealized_pnl   -2814.000000\n",
       "5         total_return_pct      12.645000\n",
       "6             total_profit   20566.000000\n",
       "7               total_loss   -7921.000000\n",
       "8               total_fees       0.000000\n",
       "9             max_drawdown  -14177.000000\n",
       "10        max_drawdown_pct     -12.272121\n",
       "11                win_rate      76.744186\n",
       "12               loss_rate      23.255814\n",
       "13          winning_trades      33.000000\n",
       "14           losing_trades      10.000000\n",
       "15                 avg_pnl     294.069767\n",
       "16          avg_return_pct       5.267674\n",
       "17          avg_trade_bars      25.488372\n",
       "18              avg_profit     623.212121\n",
       "19          avg_profit_pct       9.237576\n",
       "20  avg_winning_trade_bars      19.151515\n",
       "21                avg_loss    -792.100000\n",
       "22            avg_loss_pct      -7.833000\n",
       "23   avg_losing_trade_bars      46.400000\n",
       "24             largest_win    2715.000000\n",
       "25         largest_win_pct       9.320000\n",
       "26        largest_win_bars       2.000000\n",
       "27            largest_loss   -5054.000000\n",
       "28        largest_loss_pct     -16.140000\n",
       "29       largest_loss_bars      43.000000\n",
       "30                max_wins      13.000000\n",
       "31              max_losses       2.000000\n",
       "32                  sharpe       0.023425\n",
       "33           profit_factor       1.094471\n",
       "34             ulcer_index       1.177116\n",
       "35                     upi       0.009193\n",
       "36               equity_r2       0.772082\n",
       "37               std_error    4191.846954"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result.metrics_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8ee26a14",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>lower</th>\n",
       "      <th>upper</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>name</th>\n",
       "      <th>conf</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">Profit Factor</th>\n",
       "      <th>97.5%</th>\n",
       "      <td>0.259819</td>\n",
       "      <td>1.296660</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95%</th>\n",
       "      <td>0.303435</td>\n",
       "      <td>1.151299</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90%</th>\n",
       "      <td>0.373167</td>\n",
       "      <td>1.002514</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">Sharpe Ratio</th>\n",
       "      <th>97.5%</th>\n",
       "      <td>-0.359565</td>\n",
       "      <td>0.050383</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95%</th>\n",
       "      <td>-0.332180</td>\n",
       "      <td>0.018154</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90%</th>\n",
       "      <td>-0.276757</td>\n",
       "      <td>-0.018004</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                        lower     upper\n",
       "name          conf                     \n",
       "Profit Factor 97.5%  0.259819  1.296660\n",
       "              95%    0.303435  1.151299\n",
       "              90%    0.373167  1.002514\n",
       "Sharpe Ratio  97.5% -0.359565  0.050383\n",
       "              95%   -0.332180  0.018154\n",
       "              90%   -0.276757 -0.018004"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result.bootstrap.conf_intervals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5b3b2433",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>amount</th>\n",
       "      <th>percent</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>conf</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>99.9%</th>\n",
       "      <td>-13917.50</td>\n",
       "      <td>-12.190522</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99%</th>\n",
       "      <td>-11058.25</td>\n",
       "      <td>-9.693729</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95%</th>\n",
       "      <td>-8380.25</td>\n",
       "      <td>-7.480589</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90%</th>\n",
       "      <td>-7129.00</td>\n",
       "      <td>-6.403027</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         amount    percent\n",
       "conf                      \n",
       "99.9% -13917.50 -12.190522\n",
       "99%   -11058.25  -9.693729\n",
       "95%    -8380.25  -7.480589\n",
       "90%    -7129.00  -6.403027"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result.bootstrap.drawdown_conf"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f1abad3",
   "metadata": {},
   "source": [
    "In summary, we have now completed the process of training and backtesting a linear regression model using **PyBroker**, with the help of Walkforward Analysis. The metrics that we have seen are based on the test data from all of the time windows in the backtest. Although our trading strategy needs to be improved, we have gained a good understanding of how to train and evaluate a model in **PyBroker**.\n",
    "\n",
    "Please keep in mind that before conducting regression analysis, it is important to verify certain assumptions such as [homoscedasticity](https://en.wikipedia.org/wiki/Homoscedasticity_and_heteroscedasticity) and normality of residuals. These checks are omitted here for brevity, and we recommend performing them on your own.\n",
    "\n",
    "We are also not limited to just building linear regression models in **PyBroker**. We can train other model types such as gradient boosted machines, neural networks, or any other architecture that we choose. This flexibility allows us to explore and experiment with various models and approaches to find the best performing model for our specific trading goals.\n",
    "\n",
    "**PyBroker** also offers customization options, such as the ability to specify an [input_data_fn](https://www.pybroker.com/en/latest/reference/pybroker.model.html#pybroker.model.model) for our model in case we need to customize how its input data is built. This would be required when constructing input for autoregressive models (e.g. ARMA or RNN) that use multiple past values to make predictions. Similarly, we can specify our own [predict_fn](https://www.pybroker.com/en/latest/reference/pybroker.model.html#pybroker.model.model) to customize how predictions are made (by default, the model's ```predict``` function is called).\n",
    "\n",
    "With this knowledge, you can start building and testing your own models and trading strategies in **PyBroker**, and begin exploring the vast possibilities that this framework offers!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
